import sys

import numpy as np
from numpy.linalg import inv
from scipy.optimize import minimize
from scipy.special import expit

from regression import RegularizedMNLRegression

def v_norm(x: np.ndarray, V: np.ndarray) -> np.ndarray:
    """
    Return the V-weighted norm sqrt(x_i^T V x_i) of every row x_i of ``x``.

    Args:
        x (np.ndarray): (N, d) matrix whose rows are measured.
        V (np.ndarray): (d, d) weighting matrix, expected positive definite.

    Returns:
        np.ndarray: (N,) vector holding one norm per row of ``x``.
    """
    # One quadratic form per row: sum_{j,k} x[i, j] * V[j, k] * x[i, k].
    row_quadratic_forms = np.einsum('ij,jk,ik->i', x, V, x)
    return np.sqrt(row_quadratic_forms)

class Agent:
    """
    Base class for bandit agents.

    Every keyword argument passed to the constructor becomes an instance
    attribute. ``d`` is required: it sets the length of the zero-initialized
    parameter vector ``theta``. Subclasses override the three hook methods
    below, which do nothing (and return None) here.
    """

    def __init__(self, **kwargs):
        # vars(self) is the instance __dict__, so every kwarg becomes an attribute.
        vars(self).update(kwargs)
        self.theta = np.zeros(self.d)

    def observe(self, x, t):
        """Hook: receive candidate features ``x`` for round ``t`` (no-op here)."""

    def choose_action(self, recommend):
        """Hook: return the action for ``recommend`` (no-op here)."""

    def update(self, X, Y, stop):
        """Hook: learn from feedback ``X``/``Y`` up to position ``stop`` (no-op here)."""

class C3UCB(Agent):
    """
    Linear UCB agent whose estimate of ``theta`` is a ridge-regression fit.

    Hyperparameters are read from the keyword arguments stored by ``Agent``:
    ``d`` (feature dimension), ``lam`` (ridge penalty), ``T``, ``K``, ``B``
    and ``C`` (a multiplier applied to the confidence radius ``beta``).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # History buffers. The leading all-zero row is a placeholder that is
        # dropped by the update made in round 1; being zero, it would not
        # change X^T Y even if it were kept.
        self.X = np.zeros((1, self.d))
        self.Y = np.zeros(1)[np.newaxis, ...]

        # Regularized Gram matrix: lam * I plus the outer products of the
        # features added in ``update``.
        self.V = np.eye(self.d) * self.lam

        # Confidence radius, scaled by the exploration multiplier C.
        self.beta = np.sqrt(self.d * np.log(1 + (self.T * self.K) / (self.lam * self.d)) + 4 * np.log(self.T)) + np.sqrt(self.lam) * self.B
        self.beta *= self.C

    def observe(self, x, t):
        """
        Score the candidate items of round ``t``.

        Args:
            x (np.ndarray): (N, d) feature matrix, one row per candidate item.
            t (int): current round, stored as ``self.round``.

        Sets ``self.xv`` (per-item ||x||_{V^{-1}}) and ``self.U`` (estimated
        mean plus ``beta`` times that width).
        """
        self.round = t
        means = np.dot(x, self.theta).squeeze()
        self.xv = v_norm(x, inv(self.V))
        self.U = means + self.beta * self.xv

    def choose_action(self, recommend):
        """Play the externally supplied ranking ``recommend`` as-is."""
        self.action = recommend
        return self.action

    def update(self, X, Y, stop):
        """
        Add one round of feedback and refit ``theta`` by ridge regression.

        Args:
            X (np.ndarray): features of the shown items, ((stop+1), 1, d);
                reshaped to (-1, d).
            Y (np.ndarray): feedback of shape ((stop+1), 2); every column but
                the last is used as the regression target.
            stop (int): rows 0..stop of the reshaped ``X`` enter the Gram matrix.
        """
        ### Record historic data ###
        observedX = X.reshape(-1, self.d)        # -> ((stop+1), d)
        observedY = Y[:, :-1].reshape(-1, 1)     # -> ((stop+1), 1)
        self.X = np.concatenate((self.X, observedX))
        self.Y = np.concatenate((self.Y, observedY))
        if self.round == 1:
            # Drop the zero placeholder row created in __init__.
            self.X = np.delete(self.X, 0, axis=0)
            self.Y = np.delete(self.Y, 0, axis=0)

        ### Update Gram matrix ###
        # Sum of outer products of rows 0..stop, as a single matrix product.
        examined = observedX[:stop + 1]
        self.V += examined.T @ examined

        ### Update theta ###
        # Solve V theta = X^T Y rather than forming V^{-1}, and keep theta
        # one-dimensional (d,) like the Agent initializer does; a (d, 1)
        # theta would broadcast wrongly against (d,) vectors elsewhere.
        self.theta = np.linalg.solve(self.V, self.X.T @ self.Y).ravel()
        
class UCBCCA(Agent):
    """
    UCB agent whose ``theta`` is fit with ``RegularizedMNLRegression``.

    Besides the common Agent hyperparameters (``d``, ``lam``, ``T``, ``K``,
    ``C``) it reads ``kappa``, which divides the confidence radius ``beta``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # History buffers with a leading all-zero placeholder entry that is
        # dropped by the update made in round 1.
        self.X = np.zeros((1, self.d))[np.newaxis, ...]   # (1, 1, d)
        self.Y = np.zeros(2)[np.newaxis, ...]              # (1, 2)
        self.V = np.eye(self.d) * self.lam

        self.regression = RegularizedMNLRegression()

        self.beta = (1 / self.kappa) * np.sqrt(self.d * np.log(1 + ((self.T * self.K) / self.d)) + 4 * np.log(self.T))
        self.beta *= self.C

    def observe(self, x, t):
        """
        Score the candidate items of round ``t``.

        Args:
            x (np.ndarray): (N, d) feature matrix, one row per candidate item.
            t (int): current round, stored as ``self.round``.

        Sets ``self.xv`` (per-item ||x||_{V^{-1}}) and ``self.U``.
        """
        self.round = t

        means = np.dot(x, self.theta).squeeze()
        self.xv = v_norm(x, inv(self.V))
        self.U = means + self.beta * self.xv

    def choose_action(self, recommend):
        """
        Move the most uncertain recommended item to the first position.

        The item of ``recommend`` with the largest ``self.xv`` (set by
        ``observe``) is swapped with the item in position 0. The swap is done
        on a copy, so the caller's ``recommend`` is left untouched.
        """
        # Position within ``recommend`` of the widest confidence width. Kept
        # as a reversed argsort (not argmax), whose tie-breaking may differ.
        uncertain_item_idx = np.argsort(self.xv[recommend])[::-1][0]

        action = recommend.copy()
        action[0], action[uncertain_item_idx] = action[uncertain_item_idx], action[0]

        self.action = action
        return self.action

    def update(self, X, Y, stop):
        """
        Append one round of feedback, update ``V`` and refit ``theta``.

        Args:
            X (np.ndarray): features of the shown items, ((stop+1), 1, d).
            Y (np.ndarray): feedback, ((stop+1), 2).
            stop (int): unused here; all rows of ``X`` enter ``V``.
        """
        ### Record historic data ###
        observedX = X  # ((stop+1), 1, d)
        observedY = Y  # ((stop+1), 2)
        self.X = np.concatenate((self.X, observedX))
        self.Y = np.concatenate((self.Y, observedY))
        if self.round == 1:
            # Drop the zero placeholder entry created in __init__.
            self.X = np.delete(self.X, 0, axis=0)
            self.Y = np.delete(self.Y, 0, axis=0)

        # Gram update: sum over the batch of x^T x for each (1, d) slice.
        self.V += np.matmul(observedX.transpose(0, 2, 1), observedX).sum(axis=0)

        ### Update theta ###
        self.regression.fit(self.theta, self.X, self.Y, self.lam)
        self.theta = self.regression.w
    
class CLogUCB(Agent):
    """
    Agent that ranks items by an optimistic logistic score; ``theta`` is fit
    with ``RegularizedMNLRegression``.

    Reads ``kappa`` and ``B`` in addition to the common Agent hyperparameters
    (``d``, ``lam``, ``T``, ``K``, ``C``).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # History buffers with a leading all-zero placeholder entry that is
        # dropped by the update made in round 1.
        self.X = np.zeros((1, self.d))[np.newaxis, ...]   # (1, 1, d)
        self.Y = np.zeros(2)[np.newaxis, ...]              # (1, 2)
        self.V = np.eye(self.d) * self.lam

        self.regression = RegularizedMNLRegression()
        self.beta = (self.B**2 + self.B + 19/4) * np.sqrt((self.d/self.kappa) * np.log(4 * (1+self.K * self.T)) + 2 * (self.d/self.kappa) * np.log(self.T) )
        self.beta *= self.C

    def observe(self, x, t):
        """
        Score the candidates of round ``t`` and pick the top ``K`` of them.

        Args:
            x (np.ndarray): (N, d) feature matrix, one row per candidate item.
            t (int): current round, stored as ``self.round``.

        Sets ``self.xv``, ``self.U`` (logistic mean plus beta/4 times the
        width) and ``self.action`` (indices of the K largest ``U``, highest
        first).
        """
        self.round = t

        means = np.dot(x, self.theta).squeeze()
        # expit(m) == exp(m) / (1 + exp(m)), but does not overflow to nan
        # for large m.
        prob = expit(means)  # (N,)

        self.xv = v_norm(x, inv(self.V))
        self.U = prob + (1/4) * self.beta * self.xv

        self.action = np.argsort(self.U)[::-1][:self.K]

    def choose_action(self, recommend):
        """Ignore ``recommend`` and return the ranking built in ``observe``."""
        return self.action

    def update(self, X, Y, stop):
        """
        Append one round of feedback, update ``V`` and refit ``theta``.

        Args:
            X (np.ndarray): features of the shown items, ((stop+1), 1, d).
            Y (np.ndarray): feedback, ((stop+1), 2).
            stop (int): unused here; all rows of ``X`` enter ``V``.
        """
        ### Record historic data ###
        observedX = X  # ((stop+1), 1, d)
        observedY = Y  # ((stop+1), 2)
        self.X = np.concatenate((self.X, observedX))
        self.Y = np.concatenate((self.Y, observedY))
        if self.round == 1:
            # Drop the zero placeholder entry created in __init__.
            self.X = np.delete(self.X, 0, axis=0)
            self.Y = np.delete(self.Y, 0, axis=0)

        # Gram update: sum over the batch of x^T x for each (1, d) slice.
        self.V += np.matmul(observedX.transpose(0, 2, 1), observedX).sum(axis=0)

        ### Update theta ###
        self.regression.fit(self.theta, self.X, self.Y, self.lam)
        self.theta = self.regression.w

class UCBCLB(Agent):
    """
    Logistic UCB agent with an online, per-observation update of ``theta``
    and of the curvature matrix ``H``.

    Reads ``d``, ``lam``, ``T``, ``K``, ``B`` and ``C`` from the keyword
    arguments stored by ``Agent``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.eta = (1/2) * np.log(2) + self.B + 1
        self.beta = self.B * np.sqrt(self.d) * np.log(self.T * self.K) * np.log(2) + self.B ** (3/2) * np.sqrt(self.d * np.log(2))
        self.beta *= self.C
        self.H = np.eye(self.d) * self.lam

    def observe(self, x, t):
        """
        Score the candidate items of round ``t``.

        Args:
            x (np.ndarray): (N, d) feature matrix, one row per candidate item.
            t (int): current round, stored as ``self.round``.

        Sets ``self.means`` (x . theta), ``self.xh`` (per-item ||x||_{H^{-1}})
        and ``self.U``.
        """
        self.round = t

        self.means = np.dot(x, self.theta).squeeze()
        self.xh = v_norm(x, inv(self.H))
        self.U = self.means + self.beta * self.xh

    def choose_action(self, recommend):
        """
        Reorder ``recommend`` so the most informative items come first.

        ``first_position`` maximizes sigmoid'(mean) * width and
        ``second_position`` maximizes width alone. If they coincide, that item
        is swapped into position 0 (on a copy, leaving the caller's
        ``recommend`` untouched); otherwise the two items lead, followed by
        the rest in their original order.
        """
        recommend_means = self.means[recommend]
        # sigmoid'(m) = sigmoid(m) * sigmoid(-m); both factors are stable
        # for large |m|, unlike exp(m) / (1 + exp(m))**2.
        gradients = expit(recommend_means) * expit(-recommend_means)
        widths = self.xh[recommend]
        first_position = np.argsort(gradients * widths)[::-1][0]
        second_position = np.argsort(widths)[::-1][0]

        if first_position == second_position:
            new_recommend = recommend.copy()
            new_recommend[0], new_recommend[first_position] = new_recommend[first_position], new_recommend[0]
        else:
            remaining_elements = np.delete(recommend, [first_position, second_position])
            new_recommend = np.concatenate(([recommend[first_position]], [recommend[second_position]], remaining_elements))

        self.action = new_recommend
        return self.action

    def update(self, X, Y, stop):
        """
        Process the feedback of positions 0..stop one observation at a time.

        For each position k: build H~ = H + eta * sigmoid'(x.theta) x x^T,
        step theta <- theta - eta * H~^{-1} (sigmoid(x.theta) - y) x, project
        onto the radius-B ball (in the H~ norm) when the step's Euclidean norm
        exceeds 1.2 * B, then add sigmoid'(x.theta_new) x x^T to H.

        Args:
            X (np.ndarray): features of the shown items, ((stop+1), 1, d).
            Y (np.ndarray): feedback, ((stop+1), 2); column 0 is the target.
            stop (int): last position whose feedback is processed.
        """
        ### Record historic data ###
        observedX = X.reshape(-1, self.d)    # -> ((stop+1), d)
        observedY = Y[:, :-1].reshape(-1)    # -> ((stop+1),)

        def sigmoid(x, theta):
            # Numerically stable logistic of x . theta.
            return expit(np.dot(x, theta))

        def sigmoid_gradient(x, theta):
            # Derivative of the logistic at x . theta.
            mean = np.dot(x, theta)
            return expit(mean) * expit(-mean)

        def proj_fun(W, un_projected, M):
            # M-weighted distance between W and the unprojected point.
            diff = W - un_projected
            return np.sqrt(np.dot(diff, np.dot(M, diff)))

        def solve_argmin_theta(unprojected, M):
            # Closest point (in the M norm) to `unprojected` with ||x|| <= B.
            fun = lambda x: proj_fun(x, unprojected, M)
            constraints = {'type': 'ineq', 'fun': lambda x: self.B - np.linalg.norm(x)}
            opt = minimize(fun, x0=np.zeros(self.d), method='SLSQP', constraints=constraints)
            return opt.x

        theta_tk = self.theta
        H_tk = self.H

        for k in range(stop + 1):
            x_k = observedX[k]
            xx_k = np.outer(x_k, x_k)

            H_tk_tilde = H_tk + self.eta * sigmoid_gradient(x_k, theta_tk) * xx_k

            # H~^{-1} g computed by solving H~ z = g instead of inverting H~.
            grad = (sigmoid(x_k, theta_tk) - observedY[k]) * x_k
            unprojected_update = theta_tk - self.eta * np.linalg.solve(H_tk_tilde, grad)

            # NOTE(review): projection is triggered only above 1.2 * B but
            # projects onto the radius-B ball — confirm the slack is intended.
            if np.linalg.norm(unprojected_update) > self.B * 1.2:
                theta_tk = solve_argmin_theta(unprojected_update, H_tk_tilde)
            else:
                theta_tk = unprojected_update

            H_tk = H_tk + sigmoid_gradient(x_k, theta_tk) * xx_k

        self.theta = theta_tk
        self.H = H_tk

def create_agent(**kwargs):
    """
    Instantiate the agent class named by ``kwargs['algorithm']``.

    The name is looked up among this module's attributes (e.g. ``C3UCB``,
    ``UCBCCA``, ``CLogUCB``, ``UCBCLB``). The resolved class is printed, then
    constructed with all of ``kwargs``, including ``algorithm`` itself.

    Raises:
        KeyError: if ``algorithm`` is not given.
        AttributeError: if no module attribute has that name.
    """
    agent_cls = getattr(sys.modules[__name__], kwargs['algorithm'])
    print(agent_cls, flush=True)
    return agent_cls(**kwargs)


if __name__ == "__main__":
    # The module only defines agents and the create_agent factory; running it
    # directly does nothing.
    pass

